{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Customize the client\n",
    "\n",
    "Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)), and we built our own custom strategy from scratch ([part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html)).\n",
    "\n",
    "In this notebook, we revisit `NumPyClient` and introduce a new base class for building clients, simply named `Client`. In previous parts of this tutorial, we've based our client on `NumPyClient`, a convenience class which makes it easy to work with machine learning libraries that have good NumPy interoperability. With `Client`, we gain a lot of flexibility that we didn't have before, but we'll also have to do a few things that we didn't have to do before.\n",
    "\n",
    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:\n",
    "> - [Join Flower Discuss](https://discuss.flower.ai/) We'd love to hear from you in the `Introduction` topic! If anything is unclear, post in `Flower Help - Beginners`.\n",
    "> - [Join Flower Slack](https://flower.ai/join-slack) We'd love to hear from you in the `#introductions` channel! If anything is unclear, head over to the `#questions` channel.\n",
    "\n",
    "Let's go deeper and see what it takes to move from `NumPyClient` to `Client`! 🌼"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 0: Preparation\n",
    "\n",
    "Before we begin with the actual code, let's make sure that we have everything we need."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installing dependencies\n",
    "\n",
    "First, we install the necessary packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision scipy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have all dependencies installed, we can import everything we need for this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import flwr\n",
    "from flwr.client import Client, ClientApp, NumPyClient\n",
    "from flwr.common import Context\n",
    "from flwr.server import ServerApp, ServerConfig, ServerAppComponents\n",
    "from flwr.simulation import run_simulation\n",
    "from flwr_datasets import FederatedDataset\n",
    "\n",
    "DEVICE = torch.device(\"cpu\")  # Try \"cuda\" to train on GPU\n",
    "print(f\"Training on {DEVICE}\")\n",
    "print(f\"Flower {flwr.__version__} / PyTorch {torch.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. To train on the GPU, also set `DEVICE = torch.device(\"cuda\")` in the cell above; if you then see an error related to GPU availability in one of the following sections, switch back to CPU-based execution by setting `DEVICE = torch.device(\"cpu\")`. The cell above prints which device is in use: `Training on cuda` or `Training on cpu`."
   ]
  },
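  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you'd rather not switch devices by hand, a common PyTorch pattern is to use the GPU only when PyTorch can actually see one. The cell below is a small sketch of that pattern. It stores the result in `AUTO_DEVICE`, a new name we introduce only for illustration, so the `DEVICE` used throughout this notebook stays unchanged. To train on the GPU, you would assign the same expression to `DEVICE` in the import cell instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Pick CUDA only if PyTorch can see a GPU; otherwise fall back to the CPU.\n",
    "# AUTO_DEVICE is an illustrative name, so the tutorial's DEVICE is left untouched.\n",
    "AUTO_DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"A GPU-aware choice would train on {AUTO_DEVICE}\")"
   ]
  },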
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data loading\n",
    "\n",
    "Let's now define a loading function for the CIFAR-10 training and test sets. It partitions the training set into `num_partitions` smaller datasets (each split into a training and a validation set) and wraps each of them in its own `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_datasets(partition_id: int, num_partitions: int):\n",
    "    fds = FederatedDataset(dataset=\"cifar10\", partitioners={\"train\": num_partitions})\n",
    "    partition = fds.load_partition(partition_id)\n",
    "    # Divide data on each node: 80% train, 20% validation\n",
    "    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)\n",
    "    pytorch_transforms = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
    "    )\n",
    "\n",
    "    def apply_transforms(batch):\n",
    "        # Instead of passing transforms to CIFAR10(..., transform=transform)\n",
    "        # we will use this function to dataset.with_transform(apply_transforms)\n",
    "        # The transforms object is exactly the same\n",
    "        batch[\"img\"] = [pytorch_transforms(img) for img in batch[\"img\"]]\n",
    "        return batch\n",
    "\n",
    "    partition_train_test = partition_train_test.with_transform(apply_transforms)\n",
    "    trainloader = DataLoader(partition_train_test[\"train\"], batch_size=32, shuffle=True)\n",
    "    valloader = DataLoader(partition_train_test[\"test\"], batch_size=32)\n",
    "    testset = fds.load_split(\"test\").with_transform(apply_transforms)\n",
    "    testloader = DataLoader(testset, batch_size=32)\n",
    "    return trainloader, valloader, testloader"
   ]
  },
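  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how much data each client works with, here is a back-of-the-envelope sketch in plain Python. `partition_sizes` is our own helper, not part of Flower Datasets. It assumes that the partitioner splits the 50,000 CIFAR-10 training images into equal IID shards and that `train_test_split(test_size=0.2)` rounds the validation share up. With 10 partitions, each client gets about 4,000 training and 1,000 validation examples. Note that `len(trainloader)` counts batches (125 here), not examples; `len(trainloader.dataset)` gives the number of examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def partition_sizes(\n",
    "    num_examples: int, num_partitions: int, test_size: float = 0.2, batch_size: int = 32\n",
    "):\n",
    "    \"\"\"Sketch: per-partition train/validation sizes and batch counts.\n",
    "\n",
    "    Assumptions (ours, for illustration): IID shards where the first\n",
    "    `num_examples % num_partitions` shards get one extra example, and a\n",
    "    validation share that is rounded up.\n",
    "    \"\"\"\n",
    "    base, extra = divmod(num_examples, num_partitions)\n",
    "    sizes = []\n",
    "    for pid in range(num_partitions):\n",
    "        n = base + (1 if pid < extra else 0)\n",
    "        n_val = math.ceil(test_size * n)\n",
    "        n_train = n - n_val\n",
    "        sizes.append(\n",
    "            {\n",
    "                \"partition\": pid,\n",
    "                \"train\": n_train,\n",
    "                \"val\": n_val,\n",
    "                # A DataLoader without drop_last yields ceil(n / batch_size) batches\n",
    "                \"train_batches\": math.ceil(n_train / batch_size),\n",
    "                \"val_batches\": math.ceil(n_val / batch_size),\n",
    "            }\n",
    "        )\n",
    "    return sizes\n",
    "\n",
    "\n",
    "# CIFAR-10 has 50,000 training images; with 10 partitions each gets 5,000\n",
    "for row in partition_sizes(50_000, 10)[:2]:\n",
    "    print(row)"
   ]
  },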
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model training/evaluation\n",
    "\n",
    "Let's continue with the usual model definition (including `set_parameters` and `get_parameters`), training and test functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self) -> None:\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def get_parameters(net) -> List[np.ndarray]:\n",
    "    return [val.cpu().numpy() for _, val in net.state_dict().items()]\n",
    "\n",
    "\n",
    "def set_parameters(net, parameters: List[np.ndarray]):\n",
    "    params_dict = zip(net.state_dict().keys(), parameters)\n",
    "    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n",
    "    net.load_state_dict(state_dict, strict=True)\n",
    "\n",
    "\n",
    "def train(net, trainloader, epochs: int):\n",
    "    \"\"\"Train the network on the training set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(net.parameters())\n",
    "    net.train()\n",
    "    for epoch in range(epochs):\n",
    "        correct, total, epoch_loss = 0, 0, 0.0\n",
    "        for batch in trainloader:\n",
    "            images, labels = batch[\"img\"], batch[\"label\"]\n",
    "            images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = net(images)\n",
    "            loss = criterion(outputs, labels)  # reuse the forward pass above\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Metrics\n",
    "            epoch_loss += loss.item()  # .item() avoids keeping the autograd graph\n",
    "            total += labels.size(0)\n",
    "            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()\n",
    "        epoch_loss /= len(trainloader.dataset)\n",
    "        epoch_acc = correct / total\n",
    "        print(f\"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}\")\n",
    "\n",
    "\n",
    "def test(net, testloader):\n",
    "    \"\"\"Evaluate the network on the entire test set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    correct, total, loss = 0, 0, 0.0\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch in testloader:\n",
    "            images, labels = batch[\"img\"], batch[\"label\"]\n",
    "            images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
    "            outputs = net(images)\n",
    "            loss += criterion(outputs, labels).item()\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    loss /= len(testloader.dataset)\n",
    "    accuracy = correct / total\n",
    "    return loss, accuracy"
   ]
  },
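  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Where does `16 * 5 * 5` in `Net` come from? The sketch below redoes the shape arithmetic for a 32x32 CIFAR-10 image in plain Python. It assumes stride-1 unpadded convolutions and 2x2 max pooling with stride 2, as configured in `Net`. It also counts the parameters per layer (weights plus biases); `get_parameters` returns these as ten NumPy arrays, one weight and one bias per layer. The helper `conv_out` and the other names are ours, for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:\n",
    "    \"\"\"Output width/height of a convolution or pooling layer.\"\"\"\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "\n",
    "# 32x32 input -> conv1 (5x5) -> pool (2x2) -> conv2 (5x5) -> pool (2x2)\n",
    "spatial_size = 32\n",
    "spatial_size = conv_out(spatial_size, 5)  # conv1: 28\n",
    "spatial_size = conv_out(spatial_size, 2, stride=2)  # pool: 14\n",
    "spatial_size = conv_out(spatial_size, 5)  # conv2: 10\n",
    "spatial_size = conv_out(spatial_size, 2, stride=2)  # pool: 5\n",
    "flat_features = 16 * spatial_size * spatial_size  # conv2 has 16 output channels\n",
    "print(f\"Flattened features: {flat_features}\")\n",
    "\n",
    "# Weights + biases per layer, in the order the layers are defined in Net\n",
    "layer_params = {\n",
    "    \"conv1\": 6 * 3 * 5 * 5 + 6,\n",
    "    \"conv2\": 16 * 6 * 5 * 5 + 16,\n",
    "    \"fc1\": flat_features * 120 + 120,\n",
    "    \"fc2\": 120 * 84 + 84,\n",
    "    \"fc3\": 84 * 10 + 10,\n",
    "}\n",
    "total_params = sum(layer_params.values())\n",
    "print(layer_params)\n",
    "print(f\"Total parameters: {total_params} (~{total_params * 4 / 1024:.0f} KiB as float32)\")"
   ]
  },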
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Revisiting NumPyClient\n",
    "\n",
    "So far, we've implemented our client by subclassing `flwr.client.NumPyClient`. The three methods we implemented are `get_parameters`, `fit`, and `evaluate`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlowerNumPyClient(NumPyClient):\n",
    "    def __init__(self, partition_id, net, trainloader, valloader):\n",
    "        self.partition_id = partition_id\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "\n",
    "    def get_parameters(self, config):\n",
    "        print(f\"[Client {self.partition_id}] get_parameters\")\n",
    "        return get_parameters(self.net)\n",
    "\n",
    "    def fit(self, parameters, config):\n",
    "        print(f\"[Client {self.partition_id}] fit, config: {config}\")\n",
    "        set_parameters(self.net, parameters)\n",
    "        train(self.net, self.trainloader, epochs=1)\n",
    "        return get_parameters(self.net), len(self.trainloader), {}\n",
    "\n",
    "    def evaluate(self, parameters, config):\n",
    "        print(f\"[Client {self.partition_id}] evaluate, config: {config}\")\n",
    "        set_parameters(self.net, parameters)\n",
    "        loss, accuracy = test(self.net, self.valloader)\n",
    "        return float(loss), len(self.valloader), {\"accuracy\": float(accuracy)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we define the function `numpyclient_fn` that is used by Flower to create the `FlowerNumPyClient` instances on demand. Finally, we create the `ClientApp` and pass the `numpyclient_fn` to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def numpyclient_fn(context: Context) -> Client:\n",
    "    net = Net().to(DEVICE)\n",
    "    partition_id = context.node_config[\"partition-id\"]\n",
    "    num_partitions = context.node_config[\"num-partitions\"]\n",
    "    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)\n",
    "    return FlowerNumPyClient(partition_id, net, trainloader, valloader).to_client()\n",
    "\n",
    "\n",
    "# Create the ClientApp\n",
    "numpyclient = ClientApp(client_fn=numpyclient_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've seen this before; there's nothing new so far. The only *tiny* difference compared to the previous notebook is naming: we've changed `FlowerClient` to `FlowerNumPyClient` and `client_fn` to `numpyclient_fn`. Next, we configure the number of federated learning rounds using `ServerConfig` and create the `ServerApp` with this config:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def server_fn(context: Context) -> ServerAppComponents:\n",
    "    # Configure the server for 3 rounds of training\n",
    "    config = ServerConfig(num_rounds=3)\n",
    "    return ServerAppComponents(config=config)\n",
    "\n",
    "\n",
    "# Create ServerApp\n",
    "server = ServerApp(server_fn=server_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we specify the resources for each client and run the simulation to see the output we get:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify the resources each of your clients need\n",
    "# If set to None, each client is allocated 2 CPUs and 0 GPUs by default\n",
    "backend_config = {\"client_resources\": None}\n",
    "if DEVICE.type == \"cuda\":\n",
    "    backend_config = {\"client_resources\": {\"num_gpus\": 1}}\n",
    "\n",
    "NUM_PARTITIONS = 10\n",
    "\n",
    "# Run simulation\n",
    "run_simulation(\n",
    "    server_app=server,\n",
    "    client_app=numpyclient,\n",
    "    num_supernodes=NUM_PARTITIONS,\n",
    "    backend_config=backend_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This works as expected: ten clients train for three rounds of federated learning.\n",
    "\n",
    "Let's dive a little bit deeper and discuss how Flower executes this simulation. Whenever a client is selected to do some work, `run_simulation` launches the `ClientApp` object, which in turn calls the function `numpyclient_fn` to create an instance of our `FlowerNumPyClient` (along with loading the model and the data).\n",
    "\n",
    "But here's the perhaps surprising part: Flower doesn't actually use the `FlowerNumPyClient` object directly. Instead, it wraps the object to make it look like a subclass of `flwr.client.Client`, not `flwr.client.NumPyClient`; this is what the `.to_client()` call at the end of `numpyclient_fn` does. In fact, the Flower core framework doesn't know how to handle `NumPyClient`s; it only knows how to handle `Client`s. `NumPyClient` is just a convenience abstraction built on top of `Client`.\n",
    "\n",
    "Instead of building on top of `NumPyClient`, we can directly build on top of `Client`."
   ]
  },
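  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the wrapping idea concrete, here is a deliberately simplified, hypothetical sketch of the adapter pattern. `ToyFitIns`, `ToyFitRes`, `ToyNumPyClient` and `ToyClientAdapter` are our own stand-ins, not Flower's internal classes, and serialization is reduced to `np.save`/`np.load` on in-memory buffers. The sketch shows the shape of the translation. One `*Ins` object is deserialized and unpacked into the separate arguments that a `NumPyClient`-style `fit` expects. The returned tuple is then serialized and packed back into a single `*Res` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "from io import BytesIO\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_bytes(arr: np.ndarray) -> bytes:\n",
    "    \"\"\"Serialize one ndarray into .npy bytes (no pickling).\"\"\"\n",
    "    buf = BytesIO()\n",
    "    np.save(buf, arr, allow_pickle=False)\n",
    "    return buf.getvalue()\n",
    "\n",
    "\n",
    "def from_bytes(data: bytes) -> np.ndarray:\n",
    "    \"\"\"Deserialize .npy bytes back into an ndarray (no pickling).\"\"\"\n",
    "    return np.load(BytesIO(data), allow_pickle=False)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyFitIns:\n",
    "    tensors: List[bytes]  # serialized parameters\n",
    "    config: Dict[str, float] = field(default_factory=dict)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyFitRes:\n",
    "    tensors: List[bytes]  # serialized updated parameters\n",
    "    num_examples: int\n",
    "    metrics: Dict[str, float]\n",
    "\n",
    "\n",
    "class ToyNumPyClient:\n",
    "    \"\"\"NumPyClient-style: several arguments in, a tuple out, plain ndarrays.\"\"\"\n",
    "\n",
    "    def fit(\n",
    "        self, parameters: List[np.ndarray], config: Dict[str, float]\n",
    "    ) -> Tuple[List[np.ndarray], int, Dict[str, float]]:\n",
    "        # Pretend \"training\" shifts every weight by the configured step\n",
    "        step = config.get(\"step\", 0.1)\n",
    "        return [p + step for p in parameters], 100, {\"step\": step}\n",
    "\n",
    "\n",
    "class ToyClientAdapter:\n",
    "    \"\"\"Client-style: one *Ins in, one *Res out, bytes on the wire.\"\"\"\n",
    "\n",
    "    def __init__(self, inner: ToyNumPyClient) -> None:\n",
    "        self.inner = inner\n",
    "\n",
    "    def fit(self, ins: ToyFitIns) -> ToyFitRes:\n",
    "        # 1. Deserialize the incoming bytes into ndarrays\n",
    "        ndarrays = [from_bytes(t) for t in ins.tensors]\n",
    "        # 2. Unpack the single *Ins object into separate arguments\n",
    "        new_arrays, num_examples, metrics = self.inner.fit(ndarrays, ins.config)\n",
    "        # 3. Serialize the results and pack them into a single *Res object\n",
    "        return ToyFitRes([to_bytes(a) for a in new_arrays], num_examples, metrics)\n",
    "\n",
    "\n",
    "toy_client = ToyClientAdapter(ToyNumPyClient())\n",
    "toy_ins = ToyFitIns(tensors=[to_bytes(np.zeros((2, 2), dtype=np.float32))], config={\"step\": 0.5})\n",
    "toy_res = toy_client.fit(toy_ins)\n",
    "print(from_bytes(toy_res.tensors[0]), toy_res.num_examples, toy_res.metrics)"
   ]
  },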
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Moving from `NumPyClient` to `Client`\n",
    "\n",
    "Let's try to do the same thing using `Client` instead of `NumPyClient`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr.common import (\n",
    "    Code,\n",
    "    EvaluateIns,\n",
    "    EvaluateRes,\n",
    "    FitIns,\n",
    "    FitRes,\n",
    "    GetParametersIns,\n",
    "    GetParametersRes,\n",
    "    Status,\n",
    "    ndarrays_to_parameters,\n",
    "    parameters_to_ndarrays,\n",
    ")\n",
    "\n",
    "\n",
    "class FlowerClient(Client):\n",
    "    def __init__(self, partition_id, net, trainloader, valloader):\n",
    "        self.partition_id = partition_id\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "\n",
    "    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:\n",
    "        print(f\"[Client {self.partition_id}] get_parameters\")\n",
    "\n",
    "        # Get parameters as a list of NumPy ndarray's\n",
    "        ndarrays: List[np.ndarray] = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object\n",
    "        parameters = ndarrays_to_parameters(ndarrays)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return GetParametersRes(\n",
    "            status=status,\n",
    "            parameters=parameters,\n",
    "        )\n",
    "\n",
    "    def fit(self, ins: FitIns) -> FitRes:\n",
    "        print(f\"[Client {self.partition_id}] fit, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        # Update local model, train, get updated parameters\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        train(self.net, self.trainloader, epochs=1)\n",
    "        ndarrays_updated = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object\n",
    "        parameters_updated = ndarrays_to_parameters(ndarrays_updated)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return FitRes(\n",
    "            status=status,\n",
    "            parameters=parameters_updated,\n",
    "            num_examples=len(self.trainloader),\n",
    "            metrics={},\n",
    "        )\n",
    "\n",
    "    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:\n",
    "        print(f\"[Client {self.partition_id}] evaluate, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        loss, accuracy = test(self.net, self.valloader)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return EvaluateRes(\n",
    "            status=status,\n",
    "            loss=float(loss),\n",
    "            num_examples=len(self.valloader),\n",
    "            metrics={\"accuracy\": float(accuracy)},\n",
    "        )\n",
    "\n",
    "\n",
    "def client_fn(context: Context) -> Client:\n",
    "    net = Net().to(DEVICE)\n",
    "    partition_id = context.node_config[\"partition-id\"]\n",
    "    num_partitions = context.node_config[\"num-partitions\"]\n",
    "    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)\n",
    "    return FlowerClient(partition_id, net, trainloader, valloader).to_client()\n",
    "\n",
    "\n",
    "# Create the ClientApp\n",
    "client = ClientApp(client_fn=client_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we discuss the code in more detail, let's try to run it! Gotta make sure our new `Client`-based client works, right?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run simulation\n",
    "run_simulation(\n",
    "    server_app=server,\n",
    "    client_app=client,\n",
    "    num_supernodes=NUM_PARTITIONS,\n",
    "    backend_config=backend_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it, we're now using `Client`. It probably looks similar to what we've done with `NumPyClient`. So what's the difference?\n",
    "\n",
    "First of all, it's more code. But why? The difference comes from the fact that `Client` expects us to take care of parameter serialization and deserialization. For Flower to be able to send parameters over the network, it eventually needs to turn these parameters into `bytes`. Turning parameters (e.g., NumPy `ndarray`s) into raw bytes is called serialization. Turning raw bytes into something more useful (like NumPy `ndarray`s) is called deserialization. Flower needs to do both. The server serializes the parameters and sends them to the client. The client deserializes them for local training and then serializes the updated parameters to send them back to the server, which (finally!) deserializes them again in order to aggregate them with the updates received from other clients.\n",
    "\n",
    "The only *real* difference between `Client` and `NumPyClient` is that `NumPyClient` takes care of serialization and deserialization for you. It can do so because it expects you to return parameters as NumPy `ndarray`s, and it knows how to handle these. This makes working with machine learning libraries that have good NumPy support (most of them) a breeze.\n",
    "\n",
    "In terms of API, there's one major difference: all methods in `Client` take exactly one argument (e.g., `FitIns` in `Client.fit`) and return exactly one value (e.g., `FitRes` in `Client.fit`). The methods in `NumPyClient`, on the other hand, have multiple arguments (e.g., `parameters` and `config` in `NumPyClient.fit`) and multiple return values (e.g., `parameters`, `num_examples`, and `metrics` in `NumPyClient.fit`) if there are multiple things to handle. These `*Ins` and `*Res` objects in `Client` wrap all the individual values you're used to from `NumPyClient`."
   ]
  },
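  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What exactly ends up on the wire? The sketch below serializes one array shaped like `fc1.weight` in `Net` (120 x 400, float32) with NumPy's `.npy` format and deserializes it again. We believe Flower's default `ndarrays_to_parameters` works along similar lines, but treat that as an assumption: this cell only illustrates the idea. The bytes carry the raw data plus a small header with dtype and shape, and the round trip is exact. `allow_pickle=False` keeps `np.load` from unpickling (and thereby executing) arbitrary objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# A stand-in for one weight matrix, shaped like fc1.weight in Net (120 x 400)\n",
    "rng = np.random.default_rng(0)\n",
    "demo_weights = rng.standard_normal((120, 400)).astype(np.float32)\n",
    "\n",
    "# Serialize: ndarray -> bytes (the .npy header stores dtype and shape)\n",
    "demo_buffer = BytesIO()\n",
    "np.save(demo_buffer, demo_weights, allow_pickle=False)\n",
    "demo_payload = demo_buffer.getvalue()\n",
    "\n",
    "# Deserialize: bytes -> ndarray\n",
    "demo_restored = np.load(BytesIO(demo_payload), allow_pickle=False)\n",
    "\n",
    "print(f\"raw data bytes:   {demo_weights.nbytes}\")\n",
    "print(f\"serialized bytes: {len(demo_payload)} (data + .npy header)\")\n",
    "print(f\"exact round trip: {np.array_equal(demo_weights, demo_restored)}\")\n",
    "print(f\"dtype/shape kept: {demo_restored.dtype}, {demo_restored.shape}\")"
   ]
  },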
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Custom serialization\n",
    "\n",
    "Here we will explore how to implement custom serialization with a simple example.\n",
    "\n",
    "But first, what is serialization? Serialization is the process of converting an object into raw bytes, and, equally important, deserialization is the process of converting raw bytes back into an object. This is what makes network communication possible: without serialization, you could not send a Python object over the internet.\n",
    "\n",
    "Federated learning relies heavily on network communication: it trains a model by sending Python objects back and forth between the clients and the server. This means that serialization is an essential part of federated learning.\n",
    "\n",
    "In the following section, we will write a basic example where, instead of sending a serialized version of the `ndarray`s containing our parameters, we first convert the `ndarray`s into sparse matrices before sending them. This technique can save bandwidth: when the weights of a model are sparse (contain many zero entries), a sparse representation can be much smaller in bytes than the dense array.\n",
    "\n",
    "### Our custom serialization/deserialization functions\n",
    "\n",
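    "This is where the actual serialization and deserialization happen: `ndarray_to_sparse_bytes` handles serialization and `sparse_bytes_to_ndarray` handles deserialization.\n",
    "\n",
    "Note that we use PyTorch's sparse CSR (compressed sparse row) tensors, via `to_sparse_csr` and `torch.sparse_csr_tensor`, rather than `scipy.sparse` to convert our arrays. One-dimensional arrays such as biases are sent as regular dense arrays."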
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "from typing import cast\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from flwr.common.typing import NDArray, NDArrays, Parameters\n",
    "\n",
    "\n",
    "def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:\n",
    "    \"\"\"Convert NumPy ndarrays to parameters object.\"\"\"\n",
    "    tensors = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]\n",
    "    return Parameters(tensors=tensors, tensor_type=\"numpy.ndarray\")\n",
    "\n",
    "\n",
    "def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:\n",
    "    \"\"\"Convert parameters object to NumPy ndarrays.\"\"\"\n",
    "    return [sparse_bytes_to_ndarray(tensor) for tensor in parameters.tensors]\n",
    "\n",
    "\n",
    "def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:\n",
    "    \"\"\"Serialize NumPy ndarray to bytes.\"\"\"\n",
    "    bytes_io = BytesIO()\n",
    "\n",
    "    if len(ndarray.shape) > 1:\n",
    "        # Flatten all trailing dimensions (e.g., of conv kernels) into a 2-D matrix,\n",
    "        # then convert that matrix into a sparse CSR tensor\n",
    "        sparse = torch.tensor(ndarray.reshape(ndarray.shape[0], -1)).to_sparse_csr()\n",
    "\n",
    "        # And send it by utilizing the sparse matrix attributes, plus the original\n",
    "        # shape so that the receiver can rebuild the array exactly.\n",
    "        # WARNING: NEVER set allow_pickle to true.\n",
    "        # Reason: loading pickled data can execute arbitrary code\n",
    "        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html\n",
    "        np.savez(\n",
    "            bytes_io,  # type: ignore\n",
    "            crow_indices=sparse.crow_indices(),\n",
    "            col_indices=sparse.col_indices(),\n",
    "            values=sparse.values(),\n",
    "            shape=np.array(ndarray.shape, dtype=np.int64),\n",
    "            allow_pickle=False,\n",
    "        )\n",
    "    else:\n",
    "        # WARNING: NEVER set allow_pickle to true.\n",
    "        # Reason: loading pickled data can execute arbitrary code\n",
    "        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html\n",
    "        np.save(bytes_io, ndarray, allow_pickle=False)\n",
    "    return bytes_io.getvalue()\n",
    "\n",
    "\n",
    "def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:\n",
    "    \"\"\"Deserialize NumPy ndarray from bytes.\"\"\"\n",
    "    bytes_io = BytesIO(tensor)\n",
    "    # WARNING: NEVER set allow_pickle to true.\n",
    "    # Reason: loading pickled data can execute arbitrary code\n",
    "    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html\n",
    "    loader = np.load(bytes_io, allow_pickle=False)  # type: ignore\n",
    "\n",
    "    if isinstance(loader, np.ndarray):\n",
    "        # Plain .npy payload (1-D arrays such as biases): nothing to convert\n",
    "        ndarray_deserialized = loader\n",
    "    else:\n",
    "        # .npz archive: rebuild the 2-D sparse matrix with an explicit size\n",
    "        # (an inferred size would drop trailing all-zero columns), densify it,\n",
    "        # and restore the original shape\n",
    "        shape = tuple(int(dim) for dim in loader[\"shape\"])\n",
    "        ndarray_deserialized = (\n",
    "            torch.sparse_csr_tensor(\n",
    "                crow_indices=loader[\"crow_indices\"],\n",
    "                col_indices=loader[\"col_indices\"],\n",
    "                values=loader[\"values\"],\n",
    "                size=(shape[0], int(np.prod(shape[1:]))),\n",
    "            )\n",
    "            .to_dense()\n",
    "            .numpy()\n",
    "            .reshape(shape)\n",
    "        )\n",
    "    return cast(NDArray, ndarray_deserialized)"
   ]
  },
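  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `crow_indices`, `col_indices` and `values` arrays saved above are the three arrays of the compressed sparse row (CSR) format. The sketch below builds them by hand for a small 2-D matrix using NumPy only. `dense_to_csr` and `csr_bytes` are our own illustrative helpers, and the byte estimate assumes 64-bit indices and 32-bit values while ignoring the `.npz` container overhead. The size comparison for a matrix shaped like `fc1.weight` shows an important caveat: CSR only saves bandwidth when most entries are zero. Under these assumptions it breaks even at roughly one-third density, so a fully dense weight matrix becomes about three times larger."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def dense_to_csr(matrix: np.ndarray):\n",
    "    \"\"\"Build CSR arrays (row pointers, column indices, values) for a 2-D array.\"\"\"\n",
    "    crow_indices = [0]\n",
    "    col_indices, values = [], []\n",
    "    for row in matrix:\n",
    "        nonzero_cols = np.flatnonzero(row)\n",
    "        col_indices.extend(nonzero_cols.tolist())\n",
    "        values.extend(row[nonzero_cols].tolist())\n",
    "        # Each row pointer marks where the next row's entries start in `values`\n",
    "        crow_indices.append(len(col_indices))\n",
    "    return (\n",
    "        np.array(crow_indices, dtype=np.int64),\n",
    "        np.array(col_indices, dtype=np.int64),\n",
    "        np.array(values, dtype=matrix.dtype),\n",
    "    )\n",
    "\n",
    "\n",
    "def csr_bytes(rows: int, nnz: int, index_bytes: int = 8, value_bytes: int = 4) -> int:\n",
    "    \"\"\"Estimated CSR payload: (rows + 1) row pointers, nnz column indices, nnz values.\"\"\"\n",
    "    return (rows + 1) * index_bytes + nnz * (index_bytes + value_bytes)\n",
    "\n",
    "\n",
    "csr_example = np.array([[0, 0, 3, 0], [1, 0, 0, 0], [0, 0, 0, 0]], dtype=np.float32)\n",
    "ex_crow, ex_col, ex_val = dense_to_csr(csr_example)\n",
    "print(\"crow_indices:\", ex_crow)  # [0 1 2 2]\n",
    "print(\"col_indices: \", ex_col)  # [2 0]\n",
    "print(\"values:      \", ex_val)  # [3. 1.]\n",
    "\n",
    "# Dense vs. CSR size for a float32 matrix shaped like fc1.weight in Net\n",
    "fc1_rows, fc1_cols = 120, 400\n",
    "dense_bytes = fc1_rows * fc1_cols * 4\n",
    "for density in (1.0, 0.5, 0.1):\n",
    "    nnz = int(fc1_rows * fc1_cols * density)\n",
    "    print(f\"density {density}: dense {dense_bytes} B, CSR ~{csr_bytes(fc1_rows, nnz)} B\")"
   ]
  },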
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Client-side\n",
    "\n",
    "To serialize our `ndarray`s into sparse parameters, we just have to call our custom functions in our `flwr.client.Client` subclass.\n",
    "\n",
    "In `get_parameters`, we serialize the parameters we got from our network using our custom `ndarrays_to_sparse_parameters` defined above.\n",
    "\n",
    "In `fit`, we first deserialize the parameters coming from the server using our custom `sparse_parameters_to_ndarrays`, and then serialize our local results with `ndarrays_to_sparse_parameters`.\n",
    "\n",
    "In `evaluate`, we will only need to deserialize the global parameters with our custom function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr.common import (\n",
    "    Code,\n",
    "    EvaluateIns,\n",
    "    EvaluateRes,\n",
    "    FitIns,\n",
    "    FitRes,\n",
    "    GetParametersIns,\n",
    "    GetParametersRes,\n",
    "    Status,\n",
    ")\n",
    "\n",
    "\n",
    "class FlowerClient(Client):\n",
    "    def __init__(self, partition_id, net, trainloader, valloader):\n",
    "        self.partition_id = partition_id\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "\n",
    "    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:\n",
    "        print(f\"[Client {self.partition_id}] get_parameters\")\n",
    "\n",
    "        # Get parameters as a list of NumPy ndarray's\n",
    "        ndarrays: List[np.ndarray] = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object using our custom function\n",
    "        parameters = ndarrays_to_sparse_parameters(ndarrays)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return GetParametersRes(\n",
    "            status=status,\n",
    "            parameters=parameters,\n",
    "        )\n",
    "\n",
    "    def fit(self, ins: FitIns) -> FitRes:\n",
    "        print(f\"[Client {self.partition_id}] fit, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's using our custom function\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        # Update local model, train, get updated parameters\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        train(self.net, self.trainloader, epochs=1)\n",
    "        ndarrays_updated = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object using our custom function\n",
    "        parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return FitRes(\n",
    "            status=status,\n",
    "            parameters=parameters_updated,\n",
    "            num_examples=len(self.trainloader),\n",
    "            metrics={},\n",
    "        )\n",
    "\n",
    "    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:\n",
    "        print(f\"[Client {self.partition_id}] evaluate, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's using our custom function\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        loss, accuracy = test(self.net, self.valloader)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return EvaluateRes(\n",
    "            status=status,\n",
    "            loss=float(loss),\n",
    "            num_examples=len(self.valloader),\n",
    "            metrics={\"accuracy\": float(accuracy)},\n",
    "        )\n",
    "\n",
    "\n",
    "def client_fn(context: Context) -> Client:\n",
    "    net = Net().to(DEVICE)\n",
    "    partition_id = context.node_config[\"partition-id\"]\n",
    "    num_partitions = context.node_config[\"num-partitions\"]\n",
    "    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)\n",
    "    return FlowerClient(partition_id, net, trainloader, valloader).to_client()"
   ]
  },
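  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Changing the encoding on the client alone is not enough: whoever receives these bytes must decode them with the matching function. The NumPy-only sketch below (all helper names are hypothetical) shows why. Bytes written with `np.savez`, as our sparse encoder does for matrices, load as an `NpzFile` archive of named arrays rather than as a single `ndarray`. A receiver that expects plain `.npy` bytes would therefore not get the weights it needs. That is why the server-side strategy has to use the same custom functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def encode_plain(arr: np.ndarray) -> bytes:\n",
    "    \"\"\"Hypothetical default encoder: a single .npy array.\"\"\"\n",
    "    buf = BytesIO()\n",
    "    np.save(buf, arr, allow_pickle=False)\n",
    "    return buf.getvalue()\n",
    "\n",
    "\n",
    "def encode_archive(arr: np.ndarray) -> bytes:\n",
    "    \"\"\"Hypothetical custom encoder: an .npz archive of named arrays.\"\"\"\n",
    "    buf = BytesIO()\n",
    "    np.savez(buf, values=arr)\n",
    "    return buf.getvalue()\n",
    "\n",
    "\n",
    "def decode_expecting_plain(data: bytes) -> np.ndarray:\n",
    "    \"\"\"Hypothetical receiver that only understands plain .npy bytes.\"\"\"\n",
    "    loaded = np.load(BytesIO(data), allow_pickle=False)\n",
    "    if not isinstance(loaded, np.ndarray):\n",
    "        raise TypeError(f\"expected a plain ndarray, got {type(loaded).__name__}\")\n",
    "    return loaded\n",
    "\n",
    "\n",
    "mismatch_demo = np.ones((2, 3), dtype=np.float32)\n",
    "print(decode_expecting_plain(encode_plain(mismatch_demo)).shape)  # (2, 3)\n",
    "try:\n",
    "    decode_expecting_plain(encode_archive(mismatch_demo))\n",
    "except TypeError as err:\n",
    "    print(\"Mismatched decoder:\", err)"
   ]
  },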
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Server-side\n",
    "\n",
    "For this example, we will just use `FedAvg` as a strategy.\n",
    "To change the serialization and deserialization here, we only need to reimplement the `evaluate` and `aggregate_fit` methods of `FedAvg`.\n",
    "The other methods of the strategy will be inherited from the superclass `FedAvg`.\n",
    "\n",
    "As you can see, only one line changes in `evaluate`:\n",
    "\n",
    "```python\n",
    "parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)\n",
    "```\n",
    "\n",
    "And for `aggregate_fit`, we will first deserialize every result we received:\n",
    "\n",
    "```python\n",
    "weights_results = [\n",
    "    (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
    "    for _, fit_res in results\n",
    "]\n",
    "```\n",
    "\n",
    "And then serialize the aggregated result:\n",
    "\n",
    "```python\n",
    "parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))\n",
    "```"
   ]
  },
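  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`aggregate` turns the deserialized `(ndarrays, num_examples)` pairs in `weights_results` into a weighted average, computed layer by layer. The function below is a minimal NumPy sketch of that computation for illustration only; `weighted_average_sketch` is a hypothetical name, not Flower's implementation, and the two clients are made up:\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "NDArrays = List[np.ndarray]\n",
    "\n",
    "\n",
    "def weighted_average_sketch(results: List[Tuple[NDArrays, int]]) -> NDArrays:\n",
    "    \"\"\"Hypothetical helper: average each layer, weighted by num_examples.\"\"\"\n",
    "    total_examples = sum(num_examples for _, num_examples in results)\n",
    "    num_layers = len(results[0][0])\n",
    "    return [\n",
    "        sum(ndarrays[i] * num_examples for ndarrays, num_examples in results)\n",
    "        / total_examples\n",
    "        for i in range(num_layers)\n",
    "    ]\n",
    "\n",
    "\n",
    "# Two made-up clients, each with a 1-D and a 2-D layer\n",
    "client_a = [np.array([1.0, 1.0]), np.array([[0.0]])]\n",
    "client_b = [np.array([3.0, 5.0]), np.array([[4.0]])]\n",
    "\n",
    "# client_b reports 3x as many examples, so the average leans towards its values\n",
    "aggregated = weighted_average_sketch([(client_a, 10), (client_b, 30)])\n",
    "print(aggregated[0])  # values 2.5 and 4.0\n",
    "print(aggregated[1])  # value 3.0\n",
    "```\n",
    "\n",
    "The averaging itself operates on plain dense arrays; the sparse format is only used for transport. That is why `aggregate_fit` deserializes the results before aggregating and serializes the aggregated parameters afterwards."
   ]
  },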
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from logging import WARNING\n",
    "from typing import Callable, Dict, List, Optional, Tuple, Union\n",
    "\n",
    "from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar\n",
    "from flwr.common.logger import log\n",
    "from flwr.server.client_proxy import ClientProxy\n",
    "from flwr.server.strategy import FedAvg\n",
    "from flwr.server.strategy.aggregate import aggregate\n",
    "\n",
    "WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = \"\"\"\n",
    "Setting `min_available_clients` lower than `min_fit_clients` or\n",
    "`min_evaluate_clients` can cause the server to fail when there are too few clients\n",
    "connected to the server. `min_available_clients` must be set to a value larger\n",
    "than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class FedSparse(FedAvg):\n",
    "    def __init__(\n",
    "        self,\n",
    "        *,\n",
    "        fraction_fit: float = 1.0,\n",
    "        fraction_evaluate: float = 1.0,\n",
    "        min_fit_clients: int = 2,\n",
    "        min_evaluate_clients: int = 2,\n",
    "        min_available_clients: int = 2,\n",
    "        evaluate_fn: Optional[\n",
    "            Callable[\n",
    "                [int, NDArrays, Dict[str, Scalar]],\n",
    "                Optional[Tuple[float, Dict[str, Scalar]]],\n",
    "            ]\n",
    "        ] = None,\n",
    "        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,\n",
    "        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,\n",
    "        accept_failures: bool = True,\n",
    "        initial_parameters: Optional[Parameters] = None,\n",
    "        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,\n",
    "        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,\n",
    "    ) -> None:\n",
    "        \"\"\"Custom FedAvg strategy with sparse matrices.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        fraction_fit : float, optional\n",
    "            Fraction of clients used during training. Defaults to 1.0.\n",
    "        fraction_evaluate : float, optional\n",
    "            Fraction of clients used during validation. Defaults to 1.0.\n",
    "        min_fit_clients : int, optional\n",
    "            Minimum number of clients used during training. Defaults to 2.\n",
    "        min_evaluate_clients : int, optional\n",
    "            Minimum number of clients used during validation. Defaults to 2.\n",
    "        min_available_clients : int, optional\n",
    "            Minimum number of total clients in the system. Defaults to 2.\n",
    "        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]\n",
    "            Optional function used for validation. Defaults to None.\n",
    "        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional\n",
    "            Function used to configure training. Defaults to None.\n",
    "        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional\n",
    "            Function used to configure validation. Defaults to None.\n",
    "        accept_failures : bool, optional\n",
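    "            Whether or not to accept rounds containing failures. Defaults to True.\n",
    "        initial_parameters : Parameters, optional\n",
    "            Initial global model parameters. Defaults to None.\n",
    "        fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]\n",
    "            Function used to aggregate training metrics. Defaults to None.\n",
    "        evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]\n",
    "            Function used to aggregate validation metrics. Defaults to None.\n",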
    "        \"\"\"\n",
    "\n",
    "        if (\n",
    "            min_fit_clients > min_available_clients\n",
    "            or min_evaluate_clients > min_available_clients\n",
    "        ):\n",
    "            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)\n",
    "\n",
    "        super().__init__(\n",
    "            fraction_fit=fraction_fit,\n",
    "            fraction_evaluate=fraction_evaluate,\n",
    "            min_fit_clients=min_fit_clients,\n",
    "            min_evaluate_clients=min_evaluate_clients,\n",
    "            min_available_clients=min_available_clients,\n",
    "            evaluate_fn=evaluate_fn,\n",
    "            on_fit_config_fn=on_fit_config_fn,\n",
    "            on_evaluate_config_fn=on_evaluate_config_fn,\n",
    "            accept_failures=accept_failures,\n",
    "            initial_parameters=initial_parameters,\n",
    "            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,\n",
    "            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,\n",
    "        )\n",
    "\n",
    "    def evaluate(\n",
    "        self, server_round: int, parameters: Parameters\n",
    "    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:\n",
    "        \"\"\"Evaluate model parameters using an evaluation function.\"\"\"\n",
    "        if self.evaluate_fn is None:\n",
    "            # No evaluation function provided\n",
    "            return None\n",
    "\n",
    "        # We deserialize using our custom method\n",
    "        parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)\n",
    "\n",
    "        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})\n",
    "        if eval_res is None:\n",
    "            return None\n",
    "        loss, metrics = eval_res\n",
    "        return loss, metrics\n",
    "\n",
    "    def aggregate_fit(\n",
    "        self,\n",
    "        server_round: int,\n",
    "        results: List[Tuple[ClientProxy, FitRes]],\n",
    "        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],\n",
    "    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:\n",
    "        \"\"\"Aggregate fit results using weighted average.\"\"\"\n",
    "        if not results:\n",
    "            return None, {}\n",
    "        # Do not aggregate if there are failures and failures are not accepted\n",
    "        if not self.accept_failures and failures:\n",
    "            return None, {}\n",
    "\n",
    "        # We deserialize each of the results with our custom method\n",
    "        weights_results = [\n",
    "            (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
    "            for _, fit_res in results\n",
    "        ]\n",
    "\n",
    "        # We serialize the aggregated result using our custom method\n",
    "        parameters_aggregated = ndarrays_to_sparse_parameters(\n",
    "            aggregate(weights_results)\n",
    "        )\n",
    "\n",
    "        # Aggregate custom metrics if aggregation fn was provided\n",
    "        metrics_aggregated = {}\n",
    "        if self.fit_metrics_aggregation_fn:\n",
    "            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]\n",
    "            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)\n",
    "        elif server_round == 1:  # Only log this warning once\n",
    "            log(WARNING, \"No fit_metrics_aggregation_fn provided\")\n",
    "\n",
    "        return parameters_aggregated, metrics_aggregated"
   ]
  },
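  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`aggregate_fit` above logs a warning when no `fit_metrics_aggregation_fn` is provided, and `FedSparse` also accepts an `evaluate_metrics_aggregation_fn`. As the call in `aggregate_fit` shows, such a function receives a list of `(num_examples, metrics)` tuples and returns a single metrics dictionary. Below is a minimal sketch for the `accuracy` metric that our client's `evaluate` returns; the name `weighted_accuracy` is our own, and `Metrics` is redefined locally as an assumption so the snippet runs without Flower:\n",
    "\n",
    "```python\n",
    "from typing import Dict, List, Tuple, Union\n",
    "\n",
    "# Local stand-in for Flower's metrics type, so this snippet runs without Flower\n",
    "Metrics = Dict[str, Union[bool, bytes, float, int, str]]\n",
    "\n",
    "\n",
    "def weighted_accuracy(metrics: List[Tuple[int, Metrics]]) -> Metrics:\n",
    "    \"\"\"Average client accuracies, weighted by each client's num_examples.\"\"\"\n",
    "    total_examples = sum(num_examples for num_examples, _ in metrics)\n",
    "    if total_examples == 0:\n",
    "        return {}\n",
    "    weighted_sum = sum(\n",
    "        num_examples * float(client_metrics[\"accuracy\"])\n",
    "        for num_examples, client_metrics in metrics\n",
    "    )\n",
    "    return {\"accuracy\": weighted_sum / total_examples}\n",
    "\n",
    "\n",
    "print(weighted_accuracy([(100, {\"accuracy\": 0.25}), (300, {\"accuracy\": 0.75})]))\n",
    "# {'accuracy': 0.625}\n",
    "```\n",
    "\n",
    "Such a function could then be passed as `FedSparse(evaluate_metrics_aggregation_fn=weighted_accuracy)` in `server_fn`."
   ]
  },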
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now run our custom serialization example!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def server_fn(context: Context) -> ServerAppComponents:\n",
    "    # Configure the server for just 3 rounds of training\n",
    "    config = ServerConfig(num_rounds=3)\n",
    "    return ServerAppComponents(\n",
    "        config=config,\n",
    "        strategy=FedSparse(),  # <-- pass the new strategy here\n",
    "    )\n",
    "\n",
    "\n",
    "# Create the ServerApp\n",
    "server = ServerApp(server_fn=server_fn)\n",
    "\n",
    "# Run simulation\n",
    "run_simulation(\n",
    "    server_app=server,\n",
    "    client_app=client,\n",
    "    num_supernodes=NUM_PARTITIONS,\n",
    "    backend_config=backend_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recap\n",
    "\n",
    "In this part of the tutorial, we've seen how we can build clients by subclassing either `NumPyClient` or `Client`. `NumPyClient` is a convenience abstraction that makes it easier to work with machine learning libraries that have good NumPy interoperability. `Client` is a more flexible abstraction that allows us to do things that are not possible with `NumPyClient`, such as sending model parameters in a custom sparse format. In exchange, it requires us to handle parameter serialization and deserialization ourselves."
   ]
  },
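  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make this trade-off concrete: with `NumPyClient`, the conversion between NumPy arrays and the byte tensors carried by `Parameters` happens for us, while with `Client` we write it ourselves. The round trip below is a minimal sketch of such a conversion, assuming a plain `np.save`/`np.load` encoding per array; the helper names are hypothetical and this is not presented as Flower's exact wire format:\n",
    "\n",
    "```python\n",
    "from io import BytesIO\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def ndarrays_to_tensors_sketch(ndarrays: List[np.ndarray]) -> List[bytes]:\n",
    "    \"\"\"Hypothetical helper: encode each array as .npy bytes.\"\"\"\n",
    "    tensors = []\n",
    "    for ndarray in ndarrays:\n",
    "        buffer = BytesIO()\n",
    "        np.save(buffer, ndarray, allow_pickle=False)\n",
    "        tensors.append(buffer.getvalue())\n",
    "    return tensors\n",
    "\n",
    "\n",
    "def tensors_to_ndarrays_sketch(tensors: List[bytes]) -> List[np.ndarray]:\n",
    "    \"\"\"Hypothetical helper: decode .npy bytes back into arrays.\"\"\"\n",
    "    return [np.load(BytesIO(tensor), allow_pickle=False) for tensor in tensors]\n",
    "\n",
    "\n",
    "weights = [np.arange(6, dtype=np.float32).reshape(2, 3), np.zeros(3)]\n",
    "restored = tensors_to_ndarrays_sketch(ndarrays_to_tensors_sketch(weights))\n",
    "print(all(np.array_equal(a, b) for a, b in zip(weights, restored)))  # True\n",
    "```\n",
    "\n",
    "The sparse helpers used in this notebook play the same role; they only swap in a different byte encoding."
   ]
  },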
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "Before you continue, make sure to join the Flower community on Flower Discuss ([Join Flower Discuss](https://discuss.flower.ai)) and on Slack ([Join Slack](https://flower.ai/join-slack/)).\n",
    "\n",
    "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n",
    "\n",
    "This is the final part of the Flower tutorial (for now!), congratulations! You're now well equipped to understand the rest of the documentation. There are many topics we didn't cover in the tutorial; we recommend the following resources:\n",
    "\n",
    "- [Read Flower Docs](https://flower.ai/docs/)\n",
    "- [Check out Flower Code Examples](https://flower.ai/docs/examples/)\n",
    "- [Use Flower Baselines for your research](https://flower.ai/docs/baselines/)\n",
    "- [Watch Flower AI Summit 2024 videos](https://flower.ai/conf/flower-ai-summit-2024/)\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Flower-4-Client-and-NumPyClient-PyTorch.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "flower-3.7.12",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
